
Conversation

@yma11 (Contributor) commented Oct 26, 2025

Purpose

This PR uses FLASH_ATTN as the vision attention backend on the XPU platform; under the hood it calls IPEX's varlen_attention kernel by dispatching inside flash_attn_varlen_func.
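
For illustration, a minimal sketch of that dispatch (the helper name is hypothetical; ipex_ops and fa_utils are the entry points referenced in the review comments below):

# Hedged sketch, not the PR's exact code: on XPU the FLASH_ATTN vision backend
# resolves flash_attn_varlen_func to the IPEX-backed implementation, which in
# turn calls IPEX's varlen_attention kernel.
from vllm.platforms import current_platform


def resolve_flash_attn_varlen_func():
    if current_platform.is_xpu():
        # Assumes ipex_ops exposes flash_attn_varlen_func, as the diff
        # snippets below suggest.
        from vllm._ipex_ops import ipex_ops
        return ipex_ops.flash_attn_varlen_func
    # Other platforms keep the regular flash-attention path.
    from vllm.attention.utils.fa_utils import flash_attn_varlen_func
    return flash_attn_varlen_func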

Test Plan

python examples/offline_inference/vision_language.py -m glm-4v and python examples/offline_inference/vision_language.py -m qwen2_5_vl

Test Result

python examples/offline_inference/vision_language.py -m qwen2_5_vl:

Processed prompts: 100%|███████████████████████| 4/4 [00:07<00:00,  1.89s/it, est. speed input: 675.14 toks/s, output: 33.86 toks/s]
--------------------------------------------------
The image depicts a stunning view of the Tokyo Skytree, a tall broadcasting tower located in the Sumida Ward of Tokyo, Japan. The photo is taken from a low angle, looking up towards the tower. The sky is clear and blue, providing a vibrant backdrop. In the foreground, there are cherry blossom trees in
--------------------------------------------------
The image depicts a stunning view of the Tokyo Skytree, a tall broadcasting tower located in the Odaiba district of Tokyo, Japan. The skytree is surrounded by cherry blossom trees in full bloom, creating a picturesque and vibrant scene. The cherry blossoms are in full bloom, with delicate pink petals covering the branches
--------------------------------------------------
The image depicts a stunning view of the Tokyo Skytree, a tall communications and observation tower located in the Odaiba district of Tokyo, Japan. The skytree is surrounded by cherry blossom trees in full bloom, creating a picturesque and vibrant scene. The cherry blossoms are in various stages of bloom, with some branches
--------------------------------------------------
The image depicts a stunning view of the Tokyo Skytree, a tall broadcasting tower located in the Odaiba district of Tokyo, Japan. The tower is surrounded by cherry blossom trees in full bloom, creating a picturesque and vibrant scene. The cherry blossoms are in various stages of bloom, with pink and white petals covering
--------------------------------------------------

@mergify mergify bot added the qwen Related to Qwen models label Oct 26, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request successfully adds support for vision attention on XPU platforms by integrating the IPEX backend. The changes are logical and well-contained. However, there is a significant amount of duplicated code for the IPEX attention implementation in qwen2_vl.py and qwen2_5_vl.py. I've added comments suggesting a refactoring to improve code maintainability. Addressing this will make the codebase cleaner and easier to manage in the future.

Comment on lines 430 to 458
elif self.attn_backend == _Backend.IPEX:
    from vllm._ipex_ops import ipex_ops

    q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

    output = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    ipex_ops.varlen_attention(
        q,
        k,
        v,
        output,
        cu_seqlens,
        cu_seqlens,
        None,
        max_seqlen,
        max_seqlen,
        pdropout=0.0,
        softmax_scale=1.0 / (q.shape[-1] ** 0.5),
        zero_tensors=False,
        is_causal=False,
        return_softmax=False,
        gen_=None,
        window_size_left=-1,
        window_size_right=-1,
        logits_soft_cap=-1,
    )
    context_layer = rearrange(
        output, "(b s) h d -> s b (h d)", b=batch_size
    ).contiguous()
gemini-code-assist bot (Contributor), severity: high

This block of code for IPEX attention is identical to the one in vllm/model_executor/models/qwen2_vl.py. To improve maintainability and avoid code duplication, this logic should be refactored into a shared function.

For example, you could create a helper function in a common utility file (e.g., vllm/model_executor/models/vision.py):

from einops import rearrange
import torch
from vllm._ipex_ops import ipex_ops

def ipex_varlen_attention(q, k, v, cu_seqlens, max_seqlen, batch_size):
    q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

    output = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    ipex_ops.varlen_attention(
        q,
        k,
        v,
        output,
        cu_seqlens,
        cu_seqlens,
        None,
        max_seqlen,
        max_seqlen,
        pdropout=0.0,
        softmax_scale=1.0 / (q.shape[-1] ** 0.5),
        zero_tensors=False,
        is_causal=False,
        return_softmax=False,
        gen_=None,
        window_size_left=-1,
        window_size_right=-1,
        logits_soft_cap=-1,
    )
    context_layer = rearrange(
        output, "(b s) h d -> s b (h d)", b=batch_size
    ).contiguous()
    return context_layer

Then, you can call this function from both Qwen2_5_VisionAttention and Qwen2VisionAttention.
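
A call site in either model would then shrink to a single line, for example (a sketch using the hypothetical helper above):

# Hypothetical call site, replacing the duplicated _Backend.IPEX blocks:
elif self.attn_backend == _Backend.IPEX:
    context_layer = ipex_varlen_attention(
        q, k, v, cu_seqlens, max_seqlen, batch_size
    )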

Comment on lines 464 to 492
elif self.attn_backend == _Backend.IPEX:
    from vllm._ipex_ops import ipex_ops

    q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

    output = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    ipex_ops.varlen_attention(
        q,
        k,
        v,
        output,
        cu_seqlens,
        cu_seqlens,
        None,
        max_seqlen,
        max_seqlen,
        pdropout=0.0,
        softmax_scale=1.0 / (q.shape[-1] ** 0.5),
        zero_tensors=False,
        is_causal=False,
        return_softmax=False,
        gen_=None,
        window_size_left=-1,
        window_size_right=-1,
        logits_soft_cap=-1,
    )
    context_layer = rearrange(
        output, "(b s) h d -> s b (h d)", b=batch_size
    ).contiguous()
gemini-code-assist bot (Contributor), severity: high

This IPEX attention implementation is duplicated in vllm/model_executor/models/qwen2_5_vl.py. To adhere to the DRY (Don't Repeat Yourself) principle and improve code maintainability, please refactor this shared logic into a common function. I've left a more detailed suggestion on the other file.

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 118 to 122
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
    from vllm.attention.backends.registry import _Backend

    return _Backend.IPEX


P1: Returning IPEX backend breaks non-Qwen ViT models on XPU

The new get_vit_attn_backend now always returns _Backend.IPEX. A number of existing vision models (e.g. GLM-4V, Dots OCR, SigLIP2NaViT, Keye) guard their initialization with if self.attn_backend not in {FLASH_ATTN, TORCH_SDPA, XFORMERS, ROCM_AITER_FA}: raise RuntimeError(...) and were written assuming XPU would report TORCH_SDPA. With this change those models now receive _Backend.IPEX and immediately raise an unsupported backend error before any inference can run. Either keep returning TORCH_SDPA here or update every model's whitelist to include _Backend.IPEX to avoid hard-failing on XPU.
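
To make the failure mode concrete, here is a hedged sketch of the guard pattern being described (the exact whitelist literal varies per model file, and extending it is only one of the two fixes suggested):

from vllm.attention.backends.registry import _Backend

# Whitelist as several vision models write it today; adding _Backend.IPEX here
# (or keeping the platform reporting TORCH_SDPA) avoids the hard failure.
_SUPPORTED_VIT_BACKENDS = {
    _Backend.FLASH_ATTN,
    _Backend.TORCH_SDPA,
    _Backend.XFORMERS,
    _Backend.ROCM_AITER_FA,
}


def check_vit_attn_backend(attn_backend: _Backend) -> None:
    # With this PR, XPU reports _Backend.IPEX and falls into this error path.
    if attn_backend not in _SUPPORTED_VIT_BACKENDS:
        raise RuntimeError(f"Unsupported vision attention backend: {attn_backend}")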


Comment on lines 595 to 598
elif (
    self.attn_backend == _Backend.TORCH_SDPA
    or self.attn_backend == _Backend.IPEX
):
Member

Suggested change:

-elif (
-    self.attn_backend == _Backend.TORCH_SDPA
-    or self.attn_backend == _Backend.IPEX
-):
+elif self.attn_backend in (_Backend.TORCH_SDPA, _Backend.IPEX):

@DarkLight1337 (Member) left a comment

Can you also handle IPEX backend in MultiHeadAttention?
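
For context, one way such handling could look; this is a hedged sketch rather than the merged code, with the input layout and helper name assumed and only the ipex_ops.varlen_attention argument list taken from the snippets above:

import torch

from vllm._ipex_ops import ipex_ops


def ipex_mha(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
             scale: float) -> torch.Tensor:
    # Assumed layout: (batch, seq_len, num_heads, head_dim).
    batch_size, seq_len = q.shape[0], q.shape[1]
    # Equal-length sequences expressed as cumulative lengths for the varlen kernel.
    cu_seqlens = torch.arange(
        0, (batch_size + 1) * seq_len, step=seq_len,
        dtype=torch.int32, device=q.device,
    )
    q, k, v = (x.flatten(0, 1) for x in (q, k, v))  # -> (batch*seq, heads, dim)
    output = torch.empty_like(q)
    ipex_ops.varlen_attention(
        q, k, v, output,
        cu_seqlens, cu_seqlens, None,
        seq_len, seq_len,
        pdropout=0.0, softmax_scale=scale,
        zero_tensors=False, is_causal=False, return_softmax=False,
        gen_=None, window_size_left=-1, window_size_right=-1,
        logits_soft_cap=-1,
    )
    return output.view(batch_size, seq_len, -1)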

@yma11 yma11 marked this pull request as draft October 27, 2025 01:05
mergify bot commented Oct 28, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @yma11.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 28, 2025
@yma11 yma11 force-pushed the vl-xpu branch 2 times, most recently from 04011b1 to 6d3212b Compare October 28, 2025 11:41
@mergify mergify bot removed the needs-rebase label Oct 28, 2025
@yma11 yma11 changed the title [Multimodal][XPU]Add vision attn backend for xpu platform [Multimodal][XPU]Enable vision attn backend for xpu platform Oct 28, 2025
mergify bot commented Oct 30, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @yma11.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 30, 2025
@yma11 yma11 marked this pull request as ready for review October 30, 2025 13:42
@mergify mergify bot removed the needs-rebase label Oct 30, 2025
@yma11 (Contributor, Author) commented Oct 30, 2025

> Can you also handle IPEX backend in MultiHeadAttention?

@DarkLight1337 I made some changes; can you help review again? Thanks.

@DarkLight1337 (Member) left a comment

You can ignore the other pre-commit errors

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) October 30, 2025 14:04
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 30, 2025
if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
    if attn_backend == _Backend.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func
    elif current_platform.is_xpu():
@jikunshang (Collaborator) commented Oct 30, 2025

Maybe we can avoid this elif branch and use fa_utils.flash_attn_varlen_func in the else branch.
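
In other words, something along these lines (a sketch of the suggestion; the surrounding function shape is assumed):

from vllm.attention.backends.registry import _Backend


def _pick_varlen_func(attn_backend: _Backend):
    # Keep ROCm as the only special case; every other platform, XPU included,
    # goes through fa_utils, which resolves to the appropriate kernel.
    if attn_backend == _Backend.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func
    else:
        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
    return flash_attn_varlen_func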

) -> torch.Tensor:
    if is_rocm_aiter:
        from aiter import flash_attn_varlen_func
    elif current_platform.is_xpu():
Collaborator

same as above

        flash_attn_varlen_func = ops.flash_attn_varlen_func
    else:
        if use_upstream_fa:
            from flash_attn import flash_attn_varlen_func
Collaborator

Suggested change:

-            from flash_attn import flash_attn_varlen_func
+            from vllm.attention.utils.fa_utils import flash_attn_varlen_func

@yma11 (Contributor, Author)

I understand your point, but the change should go into the use_upstream_fa=False branch. Updated.
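
Roughly what that adjustment amounts to (a hedged sketch of the described change, not the merged diff):

def _pick_fa_func(use_upstream_fa: bool):
    if use_upstream_fa:
        # Upstream flash-attn import stays untouched.
        from flash_attn import flash_attn_varlen_func
    else:
        # The use_upstream_fa=False branch now goes through vLLM's wrapper.
        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
    return flash_attn_varlen_func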

    block_table: torch.Tensor | None = None,
    alibi_slopes: torch.Tensor | None = None,
    window_size: torch.Tensor | None = None,
    softcap: torch.Tensor | None = 0.0,
@jikunshang (Collaborator) commented Oct 31, 2025

Please double-check the types.

@yma11 (Contributor, Author)

Updated.
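
The concern is that the annotations do not match the defaults: softcap defaults to a float, and window_size is conventionally a pair of ints rather than a tensor. A purely illustrative corrected tail of the signature, since the merged types are not shown in this thread:

import torch


def _varlen_signature_sketch(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    block_table: torch.Tensor | None = None,
    alibi_slopes: torch.Tensor | None = None,
    window_size: list[int] | None = None,  # e.g. [-1, -1]; not a tensor
    softcap: float = 0.0,  # scalar soft cap, so annotate as float
) -> torch.Tensor:
    raise NotImplementedError("signature sketch only")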

auto-merge was automatically disabled October 31, 2025 03:54

Head branch was pushed to by a user without write access

@jikunshang jikunshang enabled auto-merge (squash) October 31, 2025 03:58
@jikunshang jikunshang merged commit 7e2729b into vllm-project:main Nov 1, 2025
55 checks passed
zhaozuy pushed a commit to zhaozuy/vllm that referenced this pull request Nov 4, 2025
…oject#27525)

Signed-off-by: Yan Ma <[email protected]>
Signed-off-by: Kunshang Ji <[email protected]>
Co-authored-by: Yejing Lai <[email protected]>
Co-authored-by: Guancheng Fu <[email protected]>
Co-authored-by: Kunshang Ji <[email protected]>